SG Former论文学习笔记
m0_58433654:
mask = local_mask_value+global_mask_value # [N 56 56] [N 28 28] [N 14 14]
mask_1 = mask.view(B, H * W) # [N 3136] [N 784] [N 196]
mask_2 = mask.permute(0, 2, 1).reshape(B, H * W) # [N 3136] [N 784] [N 196]
mask = [mask_1, mask_2]
else:
q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) # [N 2 3136 32] [N 4 784 32] [N 8 196 32]
# mask [local_mask global_mask] local_mask [value index] value [B, H, W]
# use mask to fuse
mask_1, mask_2 = mask # [[N 3136],[N 3136]] [[N 784],[N 784]] [[N 196],[N 196]]
mask_sort1, mask_sort_index1 = torch.sort(mask_1, dim=1)
mask_sort2, mask_sort_index2 = torch.sort(mask_2, dim=1)
你好,请问求出global mask和local mask把他们加起来后为什么还要再分成mask1和mask2呀?他们俩只是顺序不一样,在下面进行sort之后他们俩就完全一样了呀,下面做两次融合的结果也就一样了,感觉应该直接把global和local mask传到下一个transformer然后再分别做融合。
|